package auth

import (
	"context"
	"gitee.com/lipore/plume/errors"
	"gitee.com/lipore/plume/logger"
	"github.com/gin-gonic/gin"
	"net/http"
	"time"
)

// User is the account abstraction the endpoints in this file work against.
// Values of this type are obtained from the DataSource (via FetchUser) and are
// handed to the token-claims constructors.
type User interface {
	// VerifyPassword reports whether password is acceptable for this user.
	// authUser calls it to decide whether a login attempt succeeds.
	VerifyPassword(password string) bool
	// Id returns the user's identifier. authRefreshToken compares it against
	// the UserId carried by a refresh token.
	Id() interface{}
	// Name returns the user's name. It is not used in this file; presumably a
	// display name consumed by the claims constructors — verify there.
	Name() string
	// Account returns the user's account. It is not used in this file;
	// presumably the login identifier that FetchUser looks up — verify there.
	Account() string
	// Accesses returns the accesses granted to the user. They are passed to
	// newAccessTokenClaims when tokens are issued.
	Accesses() []Access
}

// NEndpoints bundles the collaborators needed by the login and refresh-token
// HTTP handlers defined on it.
type NEndpoints struct {
	// encoder turns access/refresh token claims into their string form.
	encoder              *tokenEncoder
	// decoder parses an incoming refresh token string back into claims.
	decoder              *tokenDecoder
	// dataSource is queried (FetchUser) to look up the user being authenticated.
	dataSource           DataSource
	// cancelAutoRefreshKey is neither assigned nor called anywhere in the code
	// visible here. NOTE(review): confirm whether it is still needed.
	cancelAutoRefreshKey func()
}

// NewNEndpoints builds the endpoint set used for login and token refresh.
//
// ctx and keystore are forwarded to the token encoder; keystore is also used
// to build the token decoder. authDs is the data source users are fetched
// from. options may be nil, in which case the defaults returned by
// NewNEndpointOptions are used instead of panicking on a nil dereference.
func NewNEndpoints(ctx context.Context, keystore PublicKeyStore, authDs DataSource, options *NEndpointOptions) *NEndpoints {
	if options == nil {
		options = NewNEndpointOptions()
	}
	encoderOptions := tokenEncoderOptions{
		maxKeyAge: options.MaxKeyAge,
	}
	return &NEndpoints{
		dataSource: authDs,
		encoder:    newTokenEncoder(ctx, keystore, encoderOptions),
		decoder:    newTokenDecoder(keystore),
	}
}

// NEndpointOptions holds the tunable settings accepted by NewNEndpoints.
type NEndpointOptions struct {
	// MaxKeyAge is passed to the token encoder as its maxKeyAge option.
	MaxKeyAge time.Duration
}

// NewNEndpointOptions returns the default options: a MaxKeyAge of 30 days.
func NewNEndpointOptions() *NEndpointOptions {
	return &NEndpointOptions{
		MaxKeyAge: 30 * 24 * time.Hour,
	}
}

// NRequest is the login request payload bound by LoginEndpoint.
// LoginEndpoint rejects the request with 400 when either field is empty.
// Only form tags are declared on these fields.
type NRequest struct {
	// Account identifies the user to log in; it is passed to FetchUser.
	Account  string `form:"account"`
	// Password is checked with User.VerifyPassword.
	Password string `form:"password"`
}

// NResponse is the JSON body returned by LoginEndpoint and
// RefreshTokenEndpoint on success.
type NResponse struct {
	// AccessToken is the encoded access token.
	AccessToken     string `json:"access_token"`
	// RefreshToken is the encoded refresh token.
	RefreshToken    string `json:"refresh_token"`
	// ExpireAt is copied from the access token claims' ExpiresAt
	// (presumably a Unix timestamp — confirm against the claims type).
	ExpireAt        int64  `json:"expire_at"`
	// RefreshExpireAt is copied from the refresh token claims' ExpiresAt.
	RefreshExpireAt int64  `json:"refresh_expire_at"`
}

// LoginEndpoint authenticates an account/password pair and, on success,
// answers 200 with a freshly issued access token and refresh token.
//
// Responses:
//   - 400 when the body cannot be bound or account/password is empty
//   - 401 when the credentials are rejected
//   - 500 when either token cannot be encoded
func (n *NEndpoints) LoginEndpoint(c *gin.Context) {
	var req NRequest
	if bindErr := c.ShouldBind(&req); bindErr != nil {
		// Content-type driven binding failed; retry as JSON and report that
		// second error if it fails as well.
		if jsonErr := c.ShouldBindJSON(&req); jsonErr != nil {
			logger.Warnf("%v", errors.WithMessage(jsonErr, "can't bind request body"))
			c.AbortWithStatus(http.StatusBadRequest)
			return
		}
	}

	if len(req.Account) == 0 || len(req.Password) == 0 {
		logger.Warnf("%v", errors.New("user or password is empty"))
		c.AbortWithStatus(http.StatusBadRequest)
		return
	}

	authErr, accessClaims, refreshClaims := n.authUser(c, req.Account, req.Password)
	if authErr != nil {
		logger.Warnf("%v", authErr)
		c.AbortWithStatus(http.StatusUnauthorized)
		return
	}

	accessString, err := n.encoder.encode(accessClaims)
	if err != nil {
		logger.Warnf("%v", errors.WithMessage(err, "create access token failed"))
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}

	refreshString, err := n.encoder.encode(refreshClaims)
	if err != nil {
		logger.Warnf("%v", errors.WithMessage(err, "create refresh token failed"))
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}

	resp := NResponse{
		AccessToken:     accessString,
		RefreshToken:    refreshString,
		ExpireAt:        accessClaims.ExpiresAt,
		RefreshExpireAt: refreshClaims.ExpiresAt,
	}
	c.JSON(http.StatusOK, &resp)
}

// NRefreshRequest is the payload bound by RefreshTokenEndpoint. It can be
// supplied either as a form field or as a JSON property named refresh_token.
type NRefreshRequest struct {
	// RefreshToken is the encoded refresh token previously issued to the
	// client; RefreshTokenEndpoint rejects an empty value with 400.
	RefreshToken string `form:"refresh_token" json:"refresh_token"`
}

// RefreshTokenEndpoint exchanges a refresh token for a new access token and a
// new refresh token.
//
// Responses:
//   - 400 when the body cannot be bound or the refresh token is empty
//   - 401 when the token cannot be decoded or does not match the stored user
//   - 500 when either new token cannot be encoded
func (n *NEndpoints) RefreshTokenEndpoint(c *gin.Context) {
	var req NRefreshRequest

	// Try content-type driven binding first, then fall back to JSON. Only the
	// error from the fallback is reported.
	bindErr := c.ShouldBind(&req)
	if bindErr != nil {
		bindErr = c.ShouldBindJSON(&req)
	}
	if bindErr != nil {
		logger.Warnf("%v", errors.WithMessage(bindErr, "can't bind request body"))
		c.AbortWithStatus(http.StatusBadRequest)
		return
	}

	if len(req.RefreshToken) == 0 {
		logger.Warnf("%v", errors.New("refresh token is empty"))
		c.AbortWithStatus(http.StatusBadRequest)
		return
	}

	presented := RefreshTokenClaims{}
	if decodeErr := n.decoder.Decode(req.RefreshToken, &presented); decodeErr != nil {
		logger.Warnf("%v", decodeErr)
		c.AbortWithStatus(http.StatusUnauthorized)
		return
	}

	err, newAccess, newRefresh := n.authRefreshToken(c, &presented)
	if err != nil {
		logger.Warnf("%v", err)
		c.AbortWithStatus(http.StatusUnauthorized)
		return
	}

	accessString, err := n.encoder.encode(newAccess)
	if err != nil {
		logger.Warnf("%v", errors.WithMessage(err, "create access token failed"))
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}

	refreshString, err := n.encoder.encode(newRefresh)
	if err != nil {
		logger.Warnf("%v", errors.WithMessage(err, "create refresh token failed"))
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}

	resp := NResponse{
		AccessToken:     accessString,
		RefreshToken:    refreshString,
		ExpireAt:        newAccess.ExpiresAt,
		RefreshExpireAt: newRefresh.ExpiresAt,
	}
	c.JSON(http.StatusOK, &resp)
}

// RefreshKey delegates to the token encoder's refreshSigningKey with ctx.
// Presumably this rotates the key used to sign tokens — confirm against the
// encoder implementation.
func (n *NEndpoints) RefreshKey(ctx context.Context) {
	n.encoder.refreshSigningKey(ctx)
}

// authUser looks up userAccount in the data source and checks password
// against it. On success it returns a nil error together with new access and
// refresh token claims for that user. When the user is not found or the
// password is rejected it returns an error and nil claims.
//
// Note the error is the FIRST return value, not the last.
func (n *NEndpoints) authUser(ctx context.Context, userAccount string, password string) (error, *AccessTokenClaims, *RefreshTokenClaims) {
	found := n.dataSource.FetchUser(ctx, userAccount)
	if found == nil || !found.VerifyPassword(password) {
		// One message for both failure modes, so the caller cannot tell an
		// unknown account from a wrong password.
		return errors.New("user/password not match"), nil, nil
	}

	granted := found.Accesses()
	return nil, newAccessTokenClaims(found, granted), newRefreshTokenClaims(found)
}

// authRefreshToken validates decoded refresh-token claims against the data
// source and, when they still match a user, returns a nil error together with
// new access and refresh token claims for that user.
//
// The token is rejected (error, nil claims) when the account named in the
// claims can no longer be fetched, or when the fetched user's Id differs from
// the UserId in the claims.
//
// Note the error is the FIRST return value, not the last.
func (n *NEndpoints) authRefreshToken(ctx context.Context, claims *RefreshTokenClaims) (error, *AccessTokenClaims, *RefreshTokenClaims) {
	user := n.dataSource.FetchUser(ctx, claims.Account)
	// FetchUser may yield nil (authUser guards against the same case); without
	// this check a token for a removed account would panic on user.Id().
	// NOTE(review): the Id comparison relies on both sides having the same
	// dynamic type after the token round-trip — confirm against the claims type.
	if user == nil || user.Id() != claims.UserId {
		return errors.New("invalid refresh token"), nil, nil
	}

	accesses := user.Accesses()
	accessToken := newAccessTokenClaims(user, accesses)
	refreshToken := newRefreshTokenClaims(user)
	return nil, accessToken, refreshToken
}

// NMiddleware returns a gin handler that identifies the caller from an access
// token and publishes the result on the gin context.
//
// For every request it extracts the token with parseAccessToken, decodes it
// with a decoder built from store, and stores the user id under
// UserIdXAuthKey and the accesses under AccessesXAuthKey. When no token is
// present or decoding fails, both values are stored as nil and the request
// still proceeds — the middleware itself never aborts a request.
func NMiddleware(store PublicKeyStore) func(c *gin.Context) {
	decoder := newTokenDecoder(store)

	return func(c *gin.Context) {
		path := c.Request.URL.Path
		var userId interface{}
		var accesses []Access
		if accessToken := parseAccessToken(c); accessToken != "" {
			accessTokenClaims := AccessTokenClaims{}
			// A token that fails to decode is treated like no token at all.
			if err := decoder.Decode(accessToken, &accessTokenClaims); err == nil {
				userId = accessTokenClaims.UserId
				accesses = accessTokenClaims.Accesses
			}
		}
		// %v rather than %d: userId is an interface{} that is nil for
		// anonymous requests and need not be an integer.
		logger.Infof("user: %v, requested: %s", userId, path)
		c.Set(UserIdXAuthKey, userId)
		c.Set(AccessesXAuthKey, accesses)
		c.Next()
	}
}

// parseAccessToken extracts the raw access token from the request.
//
// A non-empty Authorization header takes precedence: its first len("bearer ")
// bytes are dropped and the remainder is returned. A header too short to
// carry anything after that prefix yields "" instead of panicking on an
// out-of-range slice. The scheme name itself is not validated.
//
// Without an Authorization header the value of the "access" cookie is
// returned, or "" when that cookie is absent.
func parseAccessToken(c *gin.Context) string {
	const bearerPrefix = "bearer "

	if header := c.GetHeader("Authorization"); header != "" {
		if len(header) <= len(bearerPrefix) {
			// Malformed or truncated header: nothing usable after the prefix.
			return ""
		}
		return header[len(bearerPrefix):]
	}

	// The error only signals a missing cookie, which simply means the request
	// is anonymous; the returned value is "" in that case.
	cookie, _ := c.Cookie("access")
	return cookie
}
